6.2 Entscheidungsbäume visualisieren und trainieren#

Im letzten Kapitel haben wir gelernt, wie mit Scikit-Learn ein Entscheidungsbaum für binäre Klassifikationsaufgaben trainiert wird. In diesem Kapitel werden wir uns damit beschäftigen, den trainierten Entscheidungsbaum von Scikit-Learn visualisieren zu lassen. Darüber hinaus lernen wir, was das Gini-Impurity-Kriterion ist und welche weiteren Einstellmöglichkeiten es für Entscheidungsbäume in Scikit-Learn gibt.

Lernziele#

Lernziele

  • Sie können einen Entscheidungsbaum mit plot_tree visualisieren.

  • Sie wissen, was die Angaben samples und value bei der Visualisierung des Entscheidungsbaumes bedeuten.

  • Sie wissen, was das Gini-Impurity-Kriterium ist.

  • Sie kennen weitere Parameter für Entscheidungsbäume wie random_state= oder criterion=.

Entscheidungsbäume visualisieren#

Im letzten Kapitel haben wir den Entscheidungsbaum für das Autohaus mit Hilfe des Moduls Scikit-Learn trainiert. Scikit-Learn bietet in dem Untermodul sklearn.tree nicht nur Algorithmen für Entscheidungsbäume an, sondern auch ein dazu passendes Visualisierungswerkzeug. Die Funktion plot_tree zeichnet den Entscheidungsbaum. Um diese Funktion auszuprobieren, wird zunächst der Datensatz mit den Autodaten erneut geladen, das Modell Entscheidungsbaum gewählt und anschließend trainiert.

import pandas as pd 
from sklearn.tree import DecisionTreeClassifier

# Sammlung der Daten 
daten = pd.DataFrame({
    'Kilometerstand [km]': [32908, 20328, 13285, 17162, 27449, 13715, 32889,  3111, 15607, 18295],
    'Preis [EUR]': [15960, 20495, 17227, 17851, 5428, 22772, 13581, 16793, 23253, 11382],
    'verkauft': [False, True, False, True, False, True, False, True, True, False],
    },
    index=['Auto 1', 'Auto 2', 'Auto 3', 'Auto 4', 'Auto 5', 'Auto 6', 'Auto 7', 'Auto 8', 'Auto 9', 'Auto 10'])
daten.head(10)

# Auswahl des Modells: Entscheidungsbaum für Klassifikation
modell = DecisionTreeClassifier(random_state=0)

# Adaption der Daten
X = daten[['Kilometerstand [km]', 'Preis [EUR]']]
y = daten['verkauft']

# Training des Modells
modell.fit(X,y)
DecisionTreeClassifier(random_state=0)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

Nun können wir die Funktion plot_tree importieren und das trainierte Modell visualisieren lassen.

from sklearn.tree import plot_tree

plot_tree(modell)
[Text(0.4, 0.875, 'x[1] <= 16376.5\ngini = 0.5\nsamples = 10\nvalue = [5, 5]'),
 Text(0.2, 0.625, 'gini = 0.0\nsamples = 4\nvalue = [4, 0]'),
 Text(0.30000000000000004, 0.75, 'True  '),
 Text(0.6, 0.625, 'x[0] <= 13500.0\ngini = 0.278\nsamples = 6\nvalue = [1, 5]'),
 Text(0.5, 0.75, '  False'),
 Text(0.4, 0.375, 'x[0] <= 8198.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'),
 Text(0.2, 0.125, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'),
 Text(0.6, 0.125, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'),
 Text(0.8, 0.375, 'gini = 0.0\nsamples = 4\nvalue = [0, 4]')]
../_images/d2f618036a3b3a893e65457386017892f7621a9d55f1e9c2468d9ad5684e224e.png

plot_tree produziert eine Textausgabe und ein Diagramm. Die Textausgabe kann unterdrückt werden, indem hinter den Funktionsaufruf plot_tree(modell) ein Semikolon ; gesetzt wird. Das Diagramm zeichnet wie erwartet die Baumstruktur vom Wurzelknoten über die Knoten und Zweige bis hin zu den Blättern. Die Entscheidungsfragen stehen in der ersten Zeile der Knoten. Danach folgen weitere Angaben wie gini, samples und value. Um diese Angaben zu erklären, ergänzen wir zunächst weitere Angaben. Mit der Option feature_names= wird eine Liste mit den Eigenschaften ergänzt, die Option class_names= ergänzt die Klassenbezeichnungen. So erhalten wir folgendes Diagramm:

plot_tree(modell, 
    feature_names=['Kilometerstand [km]', 'Preis [EUR]'],
    class_names=['nicht verkauft', 'verkauft']);
../_images/927a2d874b4de5666d05bf4129bedcb2b13d44191c06c7248d56a465ef6caa42.png

Was gini bedeuten könnte, erschließt sich so immer noch nicht, aber die Angaben samples und values können so leichter von ihrer Bedeutung her eingeordnet werden. samples gibt die Anzahl der Datenobjekte an, die sich in diesem Knoten befinden. values listet auf, wie viele Datenobjekte die Zielgröße nicht verkauft (= False bzw. 0) haben und wie viele zu der Klasse verkauft (= True bzw. 1) gehören.

Weitere Details zu den Optionen der plot_tree-Funktion finden Sie in der Dokumentation Scikit-Learn → plot_tree.

Als nächstes widmen wir uns der Bedeutung von gini.

Was ist das Gini-Impurity-Kriterium?#

Das Gini-Impurity-Kriterium ist ein Maß für die Unreinheit eines Datensatzes. Beim Beispiel mit dem Autohaus sind im Wurzelknoten fünf Autos, die nicht verkauft wurden, und fünf verkaufte Autos. Bei zwei Klassen mit je 50 % Anteil ist das die maximale Unreinheit, die auftreten kann. Der Anteil der verkauften Autos ist genau 50 %. Bei dieser Verteilung beträgt das Gini-Impurity-Kriterium 0.5. Es gibt zwei weitere Extremfälle. Entweder sind nur verkaufte Autos im Datensatz (100 % verkaufte Autos) oder gar keine verkaufte Autos (0 % verkaufte Autos). In beiden Fällen ist der Datensatz rein, das Gini-Impurity-Kriterium ist 0. In allen anderen Fällen liegt das Gini-Impurity-Kriterium zwischen 0 und 0.5. Die Formel zur Berechnung des genauen Wertes des Gini-Impurity-Kriteriums lautet

\[\text{GI} = 1 - p^2 - (1-p)^2,\]

wenn \(p\) der prozentuale Anteil der verkauften Autos ist (das gilt natürlich allgemein für binäre Klassifikationsaufgaben und nicht nur das Autohaus-Beispiel).

Die folgende Abbildung zeigt die konkreten Werte des Gini-Impurity-Kriteriums für den prozentualen Anteil an verkauften Autos.

from numpy import linspace

p = linspace(0,1)
gini = 1 - p**2 - (1-p)**2

import plotly.express as px

fig = px.line(x = p, y = gini,
        title='Gini-Impurity-Kriterium',
        labels={'x': 'prozentualer Anteil', 'y': 'Wert des Gini-Impurity-Kriteriums'})
fig.show()

Im Diagramm können wir direkt ablesen, dass bei einem nicht verkauften Auto und fünf verkauften Autos (\(p = 0.8\bar{3}\)) das Gini-Impurity-Kriterium den Wert \(0.27\bar{7} \approx 0.278\) hat.

Das Gini-Impurity-Kriterium ist sehr wichtig für das Training eines Entscheidungsbaumes. Der Algorithmus probiert im Hintergrund verschiedene Möglichkeiten durch, mit Hilfe der Entscheidungsfragen den Datensatz zu splitten. Zu jedem Split werden dann die zugehörigen Werte des Gini-Impurity-Kriteriums berechnet. Dann wählt der Algorithmus den Split aus, der die höchste Reinheit hat (also den niedrigsten Gini-Impurity-Wert). Gilt das für mehrere Splits, dann wird zufällig ein Split ausgewählt. Jedes Training kann daher zu einem anderen Entscheidungsbaum führen. Ist dieses Verhalten nicht gewünscht, kann der optionale Parameter random_state= auf einen Integer gesetzt werden, um die Zufallszahlen zu fixieren. Das haben wir auch bereits im vorherigen Kapitel gemacht, damit die Ergebnisse vergleichbar waren.

Neben dem Gini-Impurity-Kriterium gibt es noch weitere Bewertungsmaße, um einen Entscheidungsbaum zu trainieren. In Scikit-Learn sind die beiden Alternativen log_loss und entropy verfügbar, die auf der Shannon-Entropie basieren und den Informationsgewinn durch Splits maximieren. Wir schauen uns im Folgenden an, wie diese ausgewählt werden können. Wer zuvor sich noch ein wenig mehr mit den Details von Entscheidungsbäumen beschäftigen möchte, kann sich die folgenden Videos ansehen.

Optionales Video “Entscheidungsbäume #2 - Der ID3-Algorithmus” von The Morpheus Tutorials
Optionales Video “Entscheidungsbäume #3 - Entropie und Informationsgewinn” von The Morpheus Tutorials
Optionales Video “ID3 Entscheidungsbaum” von 42 Entwickler

Entscheidungsbäume trainieren#

Der Entscheidungsbaum-Klassifikationsalgorithmus von Scikit-Learn bietet noch weitere Optionen an, wie in der Dokumentation Scikit-Learn → DecisionTreeClassifier() nachgelesen werden kann.

Sowohl bei der Initialisierung des Entscheidungsbaumes können Parameter gesetzt werden, als auch beim Verwenden der verschiedenen Methoden. Tatsächlich haben wir bereits weiter oben aus didaktischen Gründen den Parameter random_state=0 bei der Initialisierung gesetzt, damit immer der gleiche Entscheidungsbaum entsteht. In einem echten Projekt würde dieser Parameter nie verwendet werden.

Experimentieren Sie mit verschiedenen Werten für random_state (z.B. 0, 1, 2, 3). Sie werden feststellen, dass sich die Struktur des Baumes ändern kann, obwohl die Klassifikationsgenauigkeit oft ähnlich bleibt. Testen Sie auch verschiedene Splitting-Kriterien (criterion='gini' vs. criterion='entropy') und vergleichen Sie die entstehenden Bäume.

modell = DecisionTreeClassifier(criterion='entropy', random_state=3)
modell.fit(X,y)

plot_tree(modell, 
    feature_names=['Kilometerstand [km]', 'Preis [EUR]'],
    class_names=['nicht verkauft', 'verkauft']);
../_images/f2b5c0811ad26656160db0008813b03780b77d20ac0cef0ad8d43b286029ad1e.png

Durch die Verwendung von entropy als Kriterium kann sich die Struktur des Entscheidungsbaums ändern. Statt gini=... wird nun entropy=... im Diagramm angezeigt. Die grundlegende Funktionsweise bleibt jedoch gleich: Der Algorithmus wählt die Splits, die die Unreinheit am stärksten reduzieren.

Mini-Übung

Gegeben sind folgende Trainingsdaten für eine Pflanzenklassifikation:

pflanzen_daten = pd.DataFrame({
    'Blattlaenge [cm]': [2.5, 4.9, 6.3, 7.1, 3.2],
    'Blattbreite [cm]': [0.8, 1.5, 2.1, 2.5, 1.0],
    'giftig': [False, False, True, True, False]
})
  1. Trainieren Sie einen Entscheidungsbaum mit random_state=42.

  2. Visualisieren Sie den Baum mit aussagekräftigen Feature- und Klassennamen.

# Hier Ihr Code

Zusammenfassung und Ausblick#

In diesem Kapitel haben wir das Training von Entscheidungsbäumen mit Hilfe der Bibliothek Scikit-Learn vertieft. Im nächsten Kapitel widmen wir uns den Vor-, aber auch den Nachteilen von Entscheidungsbäumen.